package edu.berkeley.nlp.PCFGLA;

import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import edu.berkeley.nlp.syntax.SpanTree;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;


/**
 * Computes, for one chunk of trees, which chart states survive posterior pruning and
 * writes the resulting constraints to {@code outBaseName-myID.data}.
 * {@link #main} splits a treebank section into chunks and runs one constrainer per chunk.
 */
public class ParserConstrainer implements Callable<StringBuilder> {

	StateSetTreeList stateSetTrees;
	Grammar grammar;
	Lexicon lexicon;
	SpanPredictor spanPredictor;
	String outBaseName;
	double threshold;
	/** Base name of previously computed constraints to start from, or null for none. */
	String consName;
	boolean keepGoldTreeAlive;
	// NOTE(review): stored but not read by call(); the parser type is chosen from the grammar's runtime class instead.
	boolean useHierarchicalParser;
	/** Chunk size; set by main() before any constrainer runs and shared by all instances. */
	static int treesPerBlock;
	/** Index of this chunk; used for file names and global sentence numbering. */
	int myID;


	public ParserConstrainer(StateSetTreeList stateSetTrees, Grammar grammar, Lexicon lexicon, SpanPredictor spanPredictor, String outBaseName,
			double threshold, boolean keepGoldTreeAlive, int myID, String cons, boolean useHierarchicalParser) {

		this.stateSetTrees = stateSetTrees;
		this.grammar = grammar;
		this.lexicon = lexicon;
		this.spanPredictor = spanPredictor;
		this.outBaseName = outBaseName;
		this.threshold = threshold;
		this.consName = cons;
		this.keepGoldTreeAlive = keepGoldTreeAlive;
		this.myID = myID;
		this.useHierarchicalParser = useHierarchicalParser;
	}


	/**
	 * Command-line entry point: loads a treebank section and a grammar, splits the trees into
	 * {@code opts.nChunks} chunks, computes constraints for every chunk (in parallel when there
	 * is more than one) and prints the parser log to stdout or {@code opts.outputLog}.
	 */
	public static void main(String[] args) {
		OptionParser optParser = new OptionParser(ConditionalTrainer.Options.class);
		ConditionalTrainer.Options opts = (ConditionalTrainer.Options) optParser.parse(args, false);

		// provide feedback on command-line arguments
		System.out.println("Calling Constrainer with " + optParser.getPassedInOptions());

		String path = opts.path;
		System.out.println("Loading trees from "+path+" and using language "+opts.treebank);
		String testSetString = opts.section;
		// Constant-first comparisons: a missing section is reported below instead of throwing an NPE.
		boolean devTestSet = "dev".equals(testSetString);
		boolean finalTestSet = "final".equals(testSetString);
		boolean trainTestSet = "train".equals(testSetString);
		if (!devTestSet && !finalTestSet && !trainTestSet) {
			System.out.println("Unknown test set \""+testSetString+"\"; expected dev, final or train.");
			System.exit(1);
		}
		System.out.println(" using "+testSetString+" test set");

		Corpus corpus = new Corpus(path,opts.treebank,opts.trainingFractionToKeep,!trainTestSet);
		List<Tree<String>> testTrees;
		if (devTestSet)
			testTrees = corpus.getDevTestingTrees();
		else if (finalTestSet)
			testTrees = corpus.getFinalTestingTrees();
		else
			testTrees = corpus.getTrainTrees();

		testTrees = Corpus.binarizeAndFilterTrees(testTrees, 1,0,
				opts.maxL, Binarization.RIGHT, false,GrammarTrainer.VERBOSE, opts.markUnaryParents);

		if (!devTestSet&&opts.collapseUnaries) System.out.println("Collpasing unary chains.");
		testTrees = Corpus.filterTreesForConditional(testTrees,opts.filterAllUnaries, opts.filterStupidFrickinWHNP,!devTestSet&&opts.collapseUnaries);

		// On the training set the gold tree must always survive pruning.
		boolean keepGoldAlive = opts.keepGoldTreeAlive || trainTestSet;

		String inFileName = opts.inFile;
		System.out.println("Loading grammar from "+inFileName+".");
		ParserData pData = ParserData.Load(inFileName);
		if (pData==null) {
			System.out.println("Failed to load grammar from file "+inFileName+".");
			System.exit(1);
		}
		Grammar grammar = pData.getGrammar();
		grammar.splitRules();
		Lexicon lexicon = pData.getLexicon();
		lexicon.explicitlyComputeScores(grammar.finalLevel);
		SpanPredictor spanPredictor = pData.getSpanPredictor();

		if (opts.flattenParameters != 1.0){
			System.out.println("Flattening parameters with exponent "+opts.flattenParameters+" to reduce overconfidence.");
			grammar.removeUnlikelyRules(0,opts.flattenParameters);
			lexicon.removeUnlikelyTags(0,opts.flattenParameters);
		}

		Numberer.setNumberers(pData.getNumbs());
		Numberer tagNumberer = Numberer.getGlobalNumberer("tags");

		StateSetTreeList stateSetTrees = new StateSetTreeList(testTrees, grammar.numSubStates, false, tagNumberer);

		testTrees = null;
		String outBaseName = opts.outFileName;
		double threshold = Math.exp(opts.logT);

		int nChunks = opts.nChunks;
		if (nChunks < 1) {
			System.out.println("The number of chunks must be positive, but was "+nChunks+".");
			System.exit(1);
		}
		int nTrees = stateSetTrees.size();
		System.out.println("There are "+nTrees+" trees in this set.");
		treesPerBlock = (int)Math.ceil(nTrees/(double)nChunks);
		System.out.println("Will store "+treesPerBlock+" constraints per file, in "+nChunks+" files.");

		System.out.println("All states with posterior probability below "+threshold+" will be pruned.");
		if (keepGoldAlive) System.out.println("But the gold tree will survive!");
		System.out.println("The constraints will be written to "+outBaseName+".");

		StateSetTreeList[] trainingTrees = splitIntoChunks(stateSetTrees, nChunks, treesPerBlock);
		for (int i=0; i<nChunks; i++){
			System.out.println("Process "+i+" has "+trainingTrees[i].size()+" trees.");
		}
		stateSetTrees = null;

		ParserConstrainer[] constrainers = new ParserConstrainer[nChunks];
		for (int i=0; i<nChunks; i++){
			constrainers[i] = new ParserConstrainer(trainingTrees[i], grammar, lexicon, spanPredictor, outBaseName, threshold, keepGoldAlive, i, opts.cons, opts.hierarchicalChart);
		}

		// A single chunk runs on this thread; otherwise every chunk gets its own worker.
		ExecutorService pool = (nChunks > 1) ? Executors.newFixedThreadPool(nChunks) : null;
		List<Future<StringBuilder>> submits = new ArrayList<Future<StringBuilder>>(nChunks);
		if (pool != null) {
			for (ParserConstrainer constrainer : constrainers) {
				submits.add(pool.submit(constrainer));
			}
		}

		PrintWriter outputData = null;
		try {
			outputData = (opts.outputLog==null)
					? new PrintWriter(new OutputStreamWriter(System.out))
					: new PrintWriter(new OutputStreamWriter(new FileOutputStream(opts.outputLog), "UTF-8"), true);

			// Future.get() blocks until the chunk is done, so no polling is needed; logs are printed in chunk order.
			for (int i = 0; i < nChunks; i++) {
				StringBuilder sb;
				if (pool == null) {
					sb = constrainers[i].call();
				} else {
					try {
						sb = submits.get(i).get();
					} catch (ExecutionException e) {
						// Report the failed chunk but still print the logs of the others.
						System.out.println("Process "+i+" failed:");
						e.printStackTrace();
						continue;
					}
				}
				outputData.print(sb.toString());
			}
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			e.printStackTrace();
		} catch (IOException e) {
			// Covers failing to open opts.outputLog (FileNotFoundException / UnsupportedEncodingException).
			e.printStackTrace();
		} finally {
			if (outputData != null) {
				// The stdout writer is buffered and not auto-flushing; without this its content can be lost.
				outputData.flush();
				// Never close the writer wrapping System.out.
				if (opts.outputLog != null) outputData.close();
			}
			// Worker threads are non-daemon: without a shutdown the JVM would never exit.
			if (pool != null) pool.shutdownNow();
		}

		System.out.println("Done computing constraints.");
	}


	/**
	 * Splits {@code trees} into {@code nChunks} consecutive chunks of at most {@code perChunk}
	 * trees each; trailing chunks may be empty.
	 */
	private static StateSetTreeList[] splitIntoChunks(StateSetTreeList trees, int nChunks, int perChunk) {
		StateSetTreeList[] chunks = new StateSetTreeList[nChunks];
		for (int i = 0; i < nChunks; i++) {
			chunks[i] = new StateSetTreeList();
		}
		// perChunk is 0 only when trees is empty, in which case the loop body never runs.
		for (int i = 0; i < trees.size(); i++) {
			chunks[i / perChunk].add(trees.get(i));
		}
		return chunks;
	}


	/**
	 * Computes the surviving states for every tree of this chunk and saves them (one entry per
	 * tree, array length {@link #treesPerBlock}) to {@code outBaseName-myID.data}.
	 * If {@link #consName} is set, constraints from {@code consName-myID.data} are applied first.
	 *
	 * @return the log text produced while parsing this chunk
	 * @throws IllegalStateException if the input constraints cannot be loaded
	 */
	public StringBuilder call(){
		ConstrainedTwoChartsParser parser = (grammar instanceof HierarchicalAdaptiveGrammar) ?
				new ConstrainedHierarchicalTwoChartParser(grammar, lexicon, spanPredictor, grammar.finalLevel)
				: new ConstrainedTwoChartsParser(grammar, lexicon, spanPredictor);

		StringBuilder sb = new StringBuilder();
		int recentHistoryIndex = 0;
		boolean[][][][][] recentHistory = new boolean[treesPerBlock][][][][];

		boolean useCons = consName!=null;
		boolean[][][][][] myConstraints = null;
		if (useCons) {
			String consFile = consName+"-"+myID+".data";
			myConstraints = loadData(consFile);
			if (myConstraints == null) {
				throw new IllegalStateException("Could not load constraints from "+consFile+".");
			}
		}
		boolean[][][][] cons = null;

		for (Tree<StateSet> testTree : stateSetTrees) {
			List<StateSet> yield = testTree.getYield();
			List<String> testSentence = new ArrayList<String>(yield.size());
			for (StateSet el : yield){ testSentence.add(el.getWord()); }
			// Sentence numbers are global across chunks (1-based).
			sb.append("\n"+(myID*treesPerBlock+recentHistoryIndex+1)+". Length "+testSentence.size());

			if (useCons) {
				cons = myConstraints[recentHistoryIndex];
				parser.projectConstraints(cons);
			}

			// The gold tree is only handed to the parser when it has to be kept alive.
			Tree<StateSet> sTree = keepGoldTreeAlive ? testTree : null;
			boolean[][][][] possibleStates = parser.getPossibleStates(testSentence,sTree,threshold,cons,sb);

			// Drop the consumed input constraints so they can be garbage collected.
			if (useCons) myConstraints[recentHistoryIndex] = null;
			recentHistory[recentHistoryIndex++] = possibleStates;

			if (recentHistoryIndex%1000==0) System.out.print(".");
		}

		String fileName = outBaseName+"-"+myID+".data";
		if (!saveData(recentHistory, fileName)) {
			System.out.println("Failed to save constraints to "+fileName+".");
		}

		return sb;
	}


	/**
	 * Serializes {@code data} into a gzip-compressed Java object stream at {@code fileName}.
	 *
	 * @return true on success, false if any I/O error occurred (the error is printed)
	 */
	public static boolean saveData(boolean[][][][][] data, String fileName){
		FileOutputStream fos = null;
		ObjectOutputStream out = null;
		try {
			fos = new FileOutputStream(fileName);
			out = new ObjectOutputStream(new GZIPOutputStream(fos));
			out.writeObject(data);
			// close() writes the gzip trailer, so it is part of the success path, not just cleanup.
			out.close();
			out = null;
			fos = null;
		} catch (IOException e) {
			System.out.println("IOException: "+e);
			return false;
		} finally {
			// Streams are still non-null here only when something above failed.
			if (out != null) closeQuietly(out);
			else closeQuietly(fos);
		}
		return true;
	}


	/**
	 * Checks recursively that every node of {@code gold} survives in {@code possibleStates},
	 * indexed by [start][end] and holding tag numbers. The deepest unreachable node is printed.
	 */
	public static boolean isGoldReachable(SpanTree<String> gold, List[][] possibleStates, Numberer tagNumberer){
		boolean reachable = possibleStates[gold.getStart()][gold.getEnd()].contains(tagNumberer.number(gold.getLabel()));
		if (!reachable) {
			System.out.println("Cannot reach state "+gold.getLabel()+" spanning from "+gold.getStart()+" to "+gold.getEnd()+".");
			return false;
		}
		if (!gold.isLeaf()) {
			for (SpanTree<String> child : gold.getChildren()){
				// An unreachable descendant has already reported itself.
				if (!isGoldReachable(child, possibleStates, tagNumberer)) return false;
			}
		}
		return true;
	}

	/**
	 * Converts a (binarized) tree into a {@link SpanTree} with the same labels; preterminals
	 * become leaves. Prints a warning when a node has more than two children.
	 */
	public static SpanTree<String> convertToSpanTree(Tree<String> tree){
		if (tree.isPreTerminal()){
			return new SpanTree<String>(tree.getLabel());
		}

		if (tree.getChildren().size()>2) System.out.println("Binarize properly first!");

		SpanTree<String> spanTree = new SpanTree<String>(tree.getLabel());
		List<SpanTree<String>> spanChildren = new ArrayList<SpanTree<String>>(tree.getChildren().size());
		for (Tree<String> child : tree.getChildren()){
			spanChildren.add(convertToSpanTree(child));
		}
		spanTree.setChildren(spanChildren);
		return spanTree;
	}


	/**
	 * Reads constraints written by {@link #saveData}.
	 * NOTE: uses Java native deserialization; only load files this program produced.
	 *
	 * @return the constraints, or null if the file cannot be read (the error is printed)
	 */
	public static boolean[][][][][] loadData(String fileName) {
		FileInputStream fis = null;
		ObjectInputStream in = null;
		try {
			fis = new FileInputStream(fileName);
			in = new ObjectInputStream(new GZIPInputStream(fis));
			return (boolean[][][][][]) in.readObject();
		} catch (IOException e) {
			System.out.println("IOException\n"+e);
			return null;
		} catch (ClassNotFoundException e) {
			System.out.println("Class not found!");
			return null;
		} finally {
			// Closing the outermost stream also closes the ones it wraps.
			if (in != null) closeQuietly(in);
			else closeQuietly(fis);
		}
	}


	/** Closes {@code stream} if it is non-null, ignoring failures (used only for cleanup after an error or read). */
	private static void closeQuietly(Closeable stream) {
		if (stream == null) return;
		try {
			stream.close();
		} catch (IOException ignored) {
			// Best-effort cleanup; any primary error has already been reported.
		}
	}

}

